import math
import matplotlib.pyplot as plt
import scalevi.utils.utils_experimenter as utils_experimenter
import scalevi.utils.utils as utils
import plotnine as p9
import pandas as pd
import jax.numpy as np 
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)
import jax

# Switch matplotlib to the non-interactive "Agg" backend when running on a
# worker or server (presumably headless machines with no display — TODO
# confirm semantics of the two helpers). Note that ``|`` (not ``or``) is used,
# so both helper functions are always called.
if (utils.on_worker()
    | utils_experimenter.on_server()):
    import matplotlib as mpl
    mpl.use("Agg")


def get_plot_data(rez, param_name, transform=None, name=None, smooth=True):
    """Collect ``[series, label]`` pairs for one entry of a results mapping.

    Args:
        rez: Mapping whose ``rez[param_name]`` is either an array or a dict of
            arrays.
        param_name: Key of ``rez`` to extract.
        transform: Optional callable applied to each array before plotting.
        name: Optional suffix added to every label. It is converted to text, so
            non-string values (e.g. ints) are accepted in both branches.
        smooth: Array case only. Also append a copy smoothed with
            ``utils.smooth(..., conservative=True)``, labelled
            ``<label>_smoothed``.

    Returns:
        list of ``[array, label]`` pairs.

        - Dict case: only keys containing ``"L"`` are used. Each array is
          reshaped to ``(shape[0], -1)`` and every column becomes one series,
          labelled ``<key>[_<name>]_<column>``.
        - Array case: one series labelled ``<param_name>[_<name>]``.

    Raises:
        NotImplementedError: if ``rez[param_name]`` is neither a dict nor an
            array.
    """
    # Local import: at module level ``np`` is jax.numpy. On current jax,
    # ``jax.numpy.ndarray`` (== jax.Array) does not match plain NumPy arrays.
    import numpy as onp

    entry = rez[param_name]
    plot_data = []
    if isinstance(entry, dict):
        for k, param in entry.items():
            # Only entries whose key contains "L" are plotted.
            if "L" not in k:
                continue
            if transform:
                param = transform(param)
            # Flatten all trailing axes so every column is one trace.
            param = param.reshape(param.shape[0], -1)
            prefix = f"{k}_{name}" if name else f"{k}"
            for i in range(param.shape[1]):
                plot_data.append([param[:, i], f"{prefix}_{i}"])
    elif isinstance(entry, (np.ndarray, onp.ndarray)):
        param = transform(entry) if transform else entry
        label = f"{param_name}_{name}" if name else param_name
        plot_data.append([param, label])
        if smooth:
            plot_data.append(
                [utils.smooth(param, conservative=True), f"{label}_smoothed"])
    else:
        raise NotImplementedError(
            f"Cannot build plot data for {param_name!r} of type "
            f"{type(entry).__name__}")
    return plot_data

def plot_(*plot_data, ax=None, range_samples=None):
    """Plot each ``(series, label)`` pair as a line on ``ax`` (or pyplot).

    Args:
        *plot_data: ``(array, label)`` pairs, e.g. from ``get_plot_data``.
        ax: Target matplotlib Axes. When falsy, ``matplotlib.pyplot`` is used.
        range_samples: Optional ``(start, stop)``. Each series is cut to
            ``[start, stop)`` with ``jax.lax.slice`` (``stop=None`` means its
            full length). It is drawn at x positions starting from ``start``.

    Colours come from 10 evenly spaced samples of the ``'brg'`` colormap. They
    cycle when there are more than 10 series; previously an 11th series raised
    IndexError.
    """
    plotter = ax if ax else plt
    cmap = plt.get_cmap('brg')
    colors = [cmap(i) for i in np.linspace(0, 1, 10)]

    for i, (d, l) in enumerate(plot_data):
        offset = 0
        if range_samples:
            start = range_samples[0]
            stop = len(d) if range_samples[1] is None else range_samples[1]
            d = jax.lax.slice(d, [start], [stop])
            # Keep the original sample indices on the x axis.
            offset = start
        plotter.plot(offset + np.arange(len(d)), d, linewidth=2, label=l,
                     color=colors[i % len(colors)], alpha=0.5)

def _plot(*plot_data, ax=None, alpha=0.5, range_samples=None):
    """Plot ``(series, label)`` pairs with minor x-ticks and minor grid lines.

    Args:
        *plot_data: ``(array, label)`` pairs, e.g. from ``get_plot_data``.
        ax: Target matplotlib Axes. Defaults to the current axes
            (``plt.gca()``). The ``pyplot`` module itself has no ``.xaxis``, so
            it cannot be used directly here.
        alpha: Line transparency. Previously ignored; lines were always 0.5.
        range_samples: Optional ``(start, stop)``. Each series is cut to
            ``[start, stop)`` with ``jax.lax.slice`` (``stop=None`` means its
            full length). It is drawn at x positions starting from ``start``.

    Colours are evenly spaced over the ``'jet'`` colormap, one per series.
    """
    plotter = ax if ax is not None else plt.gca()
    cmap = plt.get_cmap('jet')
    colors = [cmap(c) for c in np.linspace(0, 1, len(plot_data))]

    # Minor x-ticks get grid lines; they have no labels (default NullFormatter).
    plotter.xaxis.grid(True, which='minor')
    plotter.xaxis.set_minor_locator(AutoMinorLocator())
    plotter.tick_params(which='minor', length=4, color='black')
    plotter.tick_params(which='major', length=6, color='black')

    for i, (d, l) in enumerate(plot_data):
        offset = 0
        if range_samples:
            start = range_samples[0]
            stop = len(d) if range_samples[1] is None else range_samples[1]
            d = jax.lax.slice(d, [start], [stop])
            # Keep the original sample indices on the x axis.
            offset = start
        plotter.plot(offset + np.arange(len(d)), d, label=l,
                     color=colors[i], alpha=alpha)

def plotter(rez, **kwargs):
    """Show a single-panel figure of the training ELBO stored in ``rez['value']``.

    Args:
        rez: Results mapping passed to ``get_plot_data``.
        **kwargs: Forwarded to ``get_plot_data`` (``transform``, ``name``,
            ``smooth``).
    """
    # The "seaborn" style was renamed "seaborn-v0_8" in matplotlib 3.6, and the
    # old name was removed in 3.8. Prefer the new name when it exists.
    style = "seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn"
    plt.style.use(style)

    fig, ax1 = plt.subplots(1, 1)
    fig.subplots_adjust(hspace=0.5)

    plot_data = get_plot_data(rez, 'value', **kwargs)
    _plot(*plot_data, ax=ax1)
    ax1.legend(ncol=3)
    ax1.title.set_text('Training ELBO')
    plt.show()

def plot_trace(df, A, B, x, y, k):
    """Build a plotnine line plot of training traces, one line per configuration.

    ``df`` is modified in place. ``n_iter*`` columns are made numeric,
    ``bilinear_*`` columns are relabelled, and ``vi_family_updated`` and
    ``identifier`` are added. ``y`` is floored at 10 below the smallest
    per-identifier maximum.

    Args:
        df: pandas DataFrame of results. Columns read: ``bilinear_out_dim``,
            ``bilinear_encoding_type``, ``bilinear_encoding_out_dim``,
            ``vi_family``, ``bilinear mean only``, ``dsgd_batch_size``,
            ``N_leaves``, ``minibatch_size``, plus ``x``, ``y`` and any facet
            columns. NOTE(review): ``bilinear mean only`` is not created here;
            the caller must supply it.
        A, B: Optional facet column names. Both given gives ``facet_grid(A~B)``;
            only one gives ``facet_wrap`` on it.
        x, y: Column names for the x and y aesthetics.
        k: Unused.

    Returns:
        plotnine.ggplot object.
    """
    def bilinear_out_map(v):
        # Map known bilinear output sizes to LaTeX labels. Missing values become
        # "NA". Unknown sizes keep their value instead of collapsing to None,
        # which would merge distinct runs into a single "C = None" identifier.
        if v is None or pd.isna(v):
            return "NA"
        if v == 65:
            return "$D + D^2$"
        if v == 20:
            return "$D + D$"
        return v

    def encoding_length_map(row):
        if row['bilinear_encoding_type'] in ["orth", "one-hot"]:
            return "NA"
        elif row['bilinear_encoding_out_dim'] is None:
            return "NA"
        else:
            return row['bilinear_encoding_out_dim']

    print(f"Shape of DataFrame received: {df.shape}")
    for c in [col for col in df.columns if "n_iter" in col]:
        df[c] = pd.to_numeric(df[c], downcast="integer")

    df['bilinear_out_dim'] = df["bilinear_out_dim"].apply(bilinear_out_map)
    df['bilinear_encoding_type'] = df["bilinear_encoding_type"].apply(
        lambda v: "NA" if v is None else v)
    df['bilinear_encoding_out_dim'] = df.apply(encoding_length_map, axis=1)
    df['vi_family_updated'] = df['vi_family'] + df['bilinear mean only'].apply(
        lambda v: " mean only " if v == "Yes" else "")

    df['identifier'] = (df['vi_family_updated']
                        + "; C = " + df['bilinear_out_dim'].astype(str)
                        + "; Batch size: " + df['dsgd_batch_size'].astype(str))
    cutoff = df.groupby(['identifier'], sort=False)[y].max().min() - 10
    # Floor y at ``cutoff``. ``clip`` leaves NaNs as NaN, whereas
    # ``where(df[y] > cutoff, cutoff)`` turned them into ``cutoff``.
    df[y] = df[y].clip(lower=cutoff)
    # Positional access: ``df['N_leaves'][0]`` is label-based and fails
    # (or picks the wrong row) on a non-default index.
    first = df.iloc[0]
    name = ("No. of users: " + str(first['N_leaves'])
            + "; Batch size: " + str(first['minibatch_size']))
    print(f"Shape of the final DataFrame before plotting: {df.shape}")

    g = (p9.ggplot(df)
         + p9.geom_line(mapping=p9.aes(x=str(x),
                                       y=str(y),
                                       color='factor(identifier)'),
                        na_rm=True,
                        alpha=0.65))

    if A is not None and B is not None:
        g = g + p9.facet_grid(str(A) + '~' + str(B))
    elif A is not None:
        g = g + p9.facet_wrap(str(A), labeller='label_both')
    elif B is not None:
        g = g + p9.facet_wrap(str(B), labeller='label_both')

    g = (g
         + p9.theme(
             legend_position='bottom',
             figure_size=(14, 6),
             dpi=400,
             legend_box=p9.element_blank(),
             legend_box_margin=0)
         + p9.labs(
             color='VI method',
             x='# of iterations',
             y="Training ELBO")
         + p9.ggtitle(name)
         + p9.guides(color=p9.guide_legend(nrow=2)))
    return g

def list_to_subplot_data(plot_lists):
    """Tag each ``(data, label)`` pair with the subplot indices derived from its label.

    The label is searched for a problem-size marker (``"N: ..."``) and a
    variational-family marker (``"q: ..."``).

    Args:
        plot_lists: Sequence of ``(data, label)`` pairs (only items 0 and 1 are
            read).

    Returns:
        dict mapping the position in ``plot_lists`` to a dict with keys
        ``data``, ``label``, ``column_id``, ``lower_cutoff``, ``upper_cutoff``,
        ``color``, ``alpha`` and ``axis_id``. All of them except ``color`` share
        the size index (0 = ``N: 16/10``, 1 = ``N: 1600/1000``,
        2 = ``N: 159978/100000``). ``color`` is the family index
        (0 = joint, 1 = branch, 2 = amortized branch).

    Raises:
        ValueError: if a label has no recognised size or family marker.
    """
    # Order matters: "N: 10" and "N: 16" are substrings of the larger sizes,
    # so the largest sizes are tested first.
    size_groups = (
        (2, ("N: 159978", "N: 100000")),
        (1, ("N: 1600", "N: 1000")),
        (0, ("N: 16", "N: 10")),
    )
    family_groups = (
        (0, ("q: GaussianWithSampleEval",
             "q: BlockGaussianWithSampleEval",
             "q: DiagonalWithSampleEval")),
        (1, ("q: BranchGaussianWithSampleEval",
             "q: BranchBlockGaussianWithSampleEval",
             "q: BranchDiagonalWithSampleEval")),
        (2, ("q: AmortizedBranchGaussianWithSampleEval",
             "q: AmortizedBranchBlockGaussianWithSampleEval",
             "q: AmortizedBranchDiagonalWithSampleEval")),
    )

    def _lookup(label, groups, what):
        # Return the index of the first group with a marker found in label.
        for idx, markers in groups:
            if any(m in label for m in markers):
                return idx
        raise ValueError(f"Cannot determine {what} from label {label!r}")

    plot_data = {}
    for i, plot in enumerate(plot_lists):
        label = plot[1]
        size_id = _lookup(label, size_groups, "problem size (N)")
        plot_data[i] = {
            "data": plot[0],
            "label": label,
            "column_id": size_id,
            "lower_cutoff": size_id,
            "upper_cutoff": size_id,
            "color": _lookup(label, family_groups, "variational family (q)"),
            "alpha": size_id,
            "axis_id": size_id,
        }
    return plot_data

def plot_from_lists(
    plot_fnames,
    list_to_subplot_data = None,
    smoothing_alphas = None,
    lower_y_limits = None,
    upper_y_limits = None,
    upper_x_limits = None,
    lower_y_cutoff = None,
    upper_y_cutoff = None,
    titles = None,
    range_of_samples = (0, None),
    subplot_shape = (1, 3),
    colors = None,
    axis_colors = None,
    hlines = None,
    ):
    """Render one figure per saved plot list and save each as a PDF.

    For every file in ``plot_fnames``:

    1. The pairs are loaded with ``utils.load_objects``.
    2. They are tagged by ``list_to_subplot_data``.
    3. Each series is smoothed with ``utils.smooth``, clamped to the
       configured y cutoffs (NaNs kept) and drawn on its axis.
    4. The figure is saved to ``../../paper/neurips2020/figures/<name>.pdf``,
       where ``<name>`` is ``fname[25:]`` without its last 4 characters.

    The layout assumes three axes in a 1-D array (the default
    ``subplot_shape=(1, 3)``): ``ax[0..2]`` are indexed directly.

    Args:
        plot_fnames: Files readable by ``utils.load_objects``.
        list_to_subplot_data: Callable tagging the loaded pairs. Defaults to
            the module-level ``list_to_subplot_data``. Previously ``None``
            raised TypeError.
        smoothing_alphas: Indexed by each entry's ``alpha`` id.
        lower_y_limits, upper_y_limits, upper_x_limits: Per-axis limits.
        lower_y_cutoff, upper_y_cutoff: Clamp values indexed by the entry's
            cutoff ids.
        titles: Per-axis titles, set on the first figure only.
        range_of_samples: Currently unused.
        subplot_shape: Passed to ``plt.subplots``.
        colors: Line colours indexed by the entry's ``color`` id.
        axis_colors: Currently unused.
        hlines: Optional per-axis y value for a black reference line.
    """
    if list_to_subplot_data is None:
        # The parameter shadows the module-level helper of the same name, so
        # the default is looked up in the module namespace.
        list_to_subplot_data = globals()["list_to_subplot_data"]

    plt.style.use("ggplot")
    plt.rcParams["mathtext.fontset"] = "cm"
    tags = ['Dense', "Block Diagonal", "Diagonal"]
    for zz, fname in enumerate(plot_fnames):
        plot_lists = utils.load_objects(fname, True)
        plot_data = list_to_subplot_data(plot_lists)
        fig, ax = plt.subplots(*subplot_shape, figsize=(10, 16), dpi=100)
        for v in plot_data.values():
            smoothed_d = utils.smooth(v['data'], smoothing_alphas[v['alpha']], True)
            lo = lower_y_cutoff[v['lower_cutoff']]
            hi = upper_y_cutoff[v['upper_cutoff']]
            # Clamp into [lo, hi] while leaving NaN entries untouched.
            smoothed_d = np.where(smoothed_d > lo, smoothed_d,
                                  np.where(np.isnan(smoothed_d), smoothed_d, lo))
            smoothed_d = np.where(smoothed_d > hi,
                                  np.where(np.isnan(smoothed_d), smoothed_d, hi),
                                  smoothed_d)
            v['data_to_plot'] = smoothed_d

            ax[v['axis_id']].plot(np.arange(len(v['data_to_plot'])),
                                  v['data_to_plot'],
                                  linewidth=2,
                                  label=v['label'],
                                  color=colors[v['color']],
                                  alpha=0.5)
        for i, ax_ in enumerate(ax):
            ax_.tick_params(axis='x', colors="black")
            ax_.tick_params(axis='y', colors="black")
            ax_.xaxis.set_ticks([0, upper_x_limits[i] // 2, upper_x_limits[i]])
            ax_.set_xlim(0, upper_x_limits[i])
            ax_.set_ylim(lower_y_limits[i], upper_y_limits[i])
            ax_.set_facecolor("white")
            ax_.spines['bottom'].set_color("black")
            ax_.spines['left'].set_color("black")
            if zz == 0:
                ax_.set_title(titles[i])
        ax[0].locator_params(tight=True, nbins=4)
        ax[1].locator_params(tight=True, nbins=5)
        ax[2].locator_params(tight=True, nbins=4)

        ax[1].ticklabel_format(axis="y", style="sci", scilimits=(0, 0))
        fig.set_size_inches(10, 2.25)

        if hlines is not None:
            for i in range(3):
                ax[i].hlines(hlines[i], 0, 200000, colors="black",
                             label=r"Marginal ($\log \, p(y\vert x)$)")
        if zz == 0:
            # A single shared legend is built from the first axis, on the
            # first figure only.
            handles, labels = ax[0].get_legend_handles_labels()
            legend = fig.legend(handles, labels, loc=(0.15, 0.2), prop={"size": 10})
            texts = legend.get_texts()
            if len(texts) > 2:
                texts[0].set_text(r'Joint ($q_{\phi}^{\mathrm{Joint}}$)')
                texts[1].set_text(r'Branch ($q_{v, w}^{\mathrm{Branch}}$)')
                texts[2].set_text(r'Amort ($q_{v, u}^{\mathrm{Amort}}$)')
            else:
                # Superscript kept inside $...$; previously it was outside the
                # math span and rendered literally.
                texts[0].set_text(r'Branch ($q_{v, w}^{\mathrm{Branch}}$)')
                texts[1].set_text(r'Amort ($q_{v, u}^{\mathrm{Amort}}$)')
            legend.get_frame().set_color("white")
            legend.get_frame().set_alpha(0.6)
        else:
            plt.legend('', frameon=False)
        ax[0].set_ylabel('ELBO')
        ax[2].yaxis.set_label_position("right")
        if zz < len(tags):
            ax[2].set_ylabel(f"({tags[zz]})")
        fig.tight_layout(pad=0.1, w_pad=0.1, h_pad=0.1)
        plt.savefig(f"../../paper/neurips2020/figures/{fname[25:][:-4]}.pdf", pad_inches=1)




